{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Table of Contents\n",
    "* [Intro](#Intro)\n",
    "\t* [Style Transfer](#Style-Transfer)\n",
    "* [Load Data](#Load-Data)\n",
    "* [Recreate Input](#Recreate-Input)\n",
    "* [Recreate Style](#Recreate-Style)\n",
    "* [Style Transfer](#Style-Transfer)\n",
    "\t* [Different Approach (TOFIX)](#Different-Approach-%28TOFIX%29)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Intro"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Exploratory notebook on the theory and concepts behind Style Transfer using CNNs. Includes toy example implementations and visualizations.\n",
    "\n",
    "([FastAI - Lesson 8](http://course.fast.ai/lessons/lesson8.html))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Style Transfer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Style transfer generates new images as a weighted combination of a target visual style and a target semantic content. The process optimizes both style and content by iteratively refining the input image; it generally uses features extracted from the internal layers of an already trained CNN to obtain a representation of the style component."
   ]
  },
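  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of the idea in symbols (the notation $x$, $c$, $s$, $\\alpha$, $\\beta$ is an assumption introduced here for illustration and does not appear in the code), the generated image can be read as the solution of a weighted optimization problem:\n",
    "\n",
    "$$x^{*} = \\arg\\min_{x} \\; \\alpha \\, \\mathcal{L}_{\\text{content}}(x, c) + \\beta \\, \\mathcal{L}_{\\text{style}}(x, s)$$\n",
    "\n",
    "where $c$ is the content image, $s$ is the style image, and $\\alpha$, $\\beta$ set the relative weight of the two terms."
   ]
  },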
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:41:18.649737",
     "start_time": "2017-10-09T08:41:14.763514Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "\n",
    "import time\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "import keras\n",
    "from keras import backend as K\n",
    "from keras.models import Model\n",
    "from keras import metrics\n",
    "from keras.applications.vgg16 import VGG16\n",
    "\n",
    "import scipy\n",
    "from scipy.optimize import fmin_l_bfgs_b\n",
    "from scipy.misc import imsave\n",
    "\n",
    "#backend.set_image_data_format('channels_last')\n",
    "#keras.backend.set_image_dim_ordering('tf')\n",
    "\n",
    "import os\n",
    "import sys\n",
    "sys.path.append(os.path.join(os.getcwd(), os.pardir))\n",
    "\n",
    "from utils.vgg_utils import preprocess, deprocess, gram_matrix\n",
    "from utils.vgg16_avg import VGG16_Avg\n",
    "\n",
    "RES_DIR = os.path.join('resources')\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:41:21.589905",
     "start_time": "2017-10-09T08:41:21.008872Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "height = 256\n",
    "width = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:41:22.359949",
     "start_time": "2017-10-09T08:41:21.591905Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# load content image\n",
    "content_image = None\n",
    "with Image.open(os.path.join(RES_DIR, 'superman.jpg')) as img:\n",
    "    img = img.resize((width, height))  # PIL expects (width, height)\n",
    "    content_image = np.asarray(img, dtype='float32')\n",
    "    plt.imshow(img.convert(mode='RGB'))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:49:01.240195",
     "start_time": "2017-10-09T08:49:00.445150Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# load style image\n",
    "style_image = None\n",
    "with Image.open(os.path.join(RES_DIR, 'comics_style.jpg')) as img:\n",
    "    img = img.resize((width, height))  # PIL expects (width, height)\n",
    "    style_image = np.asarray(img, dtype='float32')\n",
    "    plt.imshow(img.convert(mode='RGB'))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T15:57:43.964016",
     "start_time": "2017-10-08T15:57:42.766948Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "content_image.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Recreate Input"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this first step I am going to simply recreate an image from noise using the content loss."
   ]
  },
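  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, and assuming the mean is taken over all $N$ entries of the chosen layer's activation (the symbols $F^{l}$, $p$ and $N$ are notation assumed here, not names from the code), the content loss used in this section is roughly:\n",
    "\n",
    "$$\\mathcal{L}_{\\text{content}}(x) = \\frac{1}{N} \\sum_{i=1}^{N} \\left( F^{l}_{i}(x) - F^{l}_{i}(p) \\right)^{2}$$\n",
    "\n",
    "where $F^{l}(\\cdot)$ is the output of the selected VGG layer and $p$ is the image being recreated. Minimizing this with respect to $x$ pushes the noise image toward one whose features match those of $p$."
   ]
  },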
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:45:41.290590",
     "start_time": "2017-10-08T16:45:40.725557Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# define input image\n",
    "img_arr = preprocess(np.expand_dims(style_image, axis=0))\n",
    "#img_arr = preproc(np.expand_dims(np.array(Image.open(os.path.join(RES_DIR, 'simpsons_style.jpg'))), axis=0))\n",
    "shp = img_arr.shape\n",
    "\n",
    "print(shp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:39:19.238907",
     "start_time": "2017-10-09T08:39:16.239735Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# get VGG model\n",
    "model = VGG16(include_top=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:45:45.046805",
     "start_time": "2017-10-08T16:45:43.850736Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# define layer model (VGG model input and intermediate layer output)\n",
    "layer = model.get_layer('block5_conv1').output\n",
    "layer_model = Model(model.input, layer)\n",
    "targ = K.variable(layer_model.predict(img_arr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:45:46.374880",
     "start_time": "2017-10-08T16:45:45.583835Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# define our loss and gradients\n",
    "loss = K.mean(metrics.mse(layer, targ))  # K.mean reduces to a scalar, as L-BFGS requires\n",
    "grads = K.gradients(loss, model.input)\n",
    "fn = K.function([model.input], [loss]+grads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:41:27.651252",
     "start_time": "2017-10-09T08:41:27.059218Z"
    },
    "collapsed": true
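   },
   "outputs": [],
   "source": [
    "# utility class: computes loss and gradients in one pass and caches the gradients,\n",
    "# because fmin_l_bfgs_b asks for them through two separate callbacks\n",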
    "class Evaluator(object):\n",
    "    def __init__(self, f, shp): self.f, self.shp = f, shp\n",
    "        \n",
    "    def loss(self, x):\n",
    "        loss_, self.grad_values = self.f([x.reshape(self.shp)])\n",
    "        return loss_.astype(np.float64)\n",
    "\n",
    "    def grads(self, x): return self.grad_values.flatten().astype(np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:45:48.286990",
     "start_time": "2017-10-08T16:45:47.713957Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "evaluator = Evaluator(fn, shp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:41:28.312289",
     "start_time": "2017-10-09T08:41:27.726256Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# run optimization process and save result image at each iteration\n",
    "def solve_image(eval_obj, iterations, x, img_shape, dest_dir=''):\n",
    "    for i in range(iterations):\n",
    "        start_time = time.time()\n",
    "        x, min_val, info = fmin_l_bfgs_b(eval_obj.loss, x.flatten(),\n",
    "                                         fprime=eval_obj.grads, maxfun=20)\n",
    "        x = np.clip(x, -127,127)\n",
    "        print('Current loss value:', min_val)\n",
    "        end_time = time.time()\n",
    "        print('Iteration {} completed in {:.1f}s'.format(i, end_time - start_time))\n",
    "        img = deprocess(x.copy(), img_shape)[0]\n",
    "        img_filepath = os.path.join(dest_dir, \"res_at_iteration_{}.png\".format(i))\n",
    "        imsave(img_filepath, img)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:46:10.042234",
     "start_time": "2017-10-08T16:46:09.357195Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "x = np.random.uniform(-2.5, 2.5, shp)\n",
    "#x = np.random.uniform(0, 255, shp) - 128.\n",
    "plt.imshow(x[0]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:48:12.597244",
     "start_time": "2017-10-08T16:46:13.455429Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "x = solve_image(evaluator, 5, x, shp, dest_dir='recreate_input')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T12:07:38.322356",
     "start_time": "2017-10-08T12:07:36.999281Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "plt.imshow(deprocess(x.copy(), shp)[0].astype('uint8'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Recreate Style"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While in the previous section we recreated the input image from noise, here we recreate the style from noise."
   ]
  },
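  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The style representation is usually built from Gram matrices of layer activations. The sketch below assumes that `gram_matrix` follows this usual definition, possibly up to a normalization constant; the helper lives in `utils.vgg_utils` and its actual implementation is not shown in this notebook. With $F^{l}$ the activations of layer $l$ reshaped to channels by positions, and $s$ the style image (notation assumed here):\n",
    "\n",
    "$$G^{l}_{ij}(x) = \\sum_{k} F^{l}_{ik}(x) \\, F^{l}_{jk}(x)$$\n",
    "\n",
    "$$\\mathcal{L}_{\\text{style}}(x) = \\sum_{l} \\operatorname{mean}\\left( \\left( G^{l}(x) - G^{l}(s) \\right)^{2} \\right)$$\n",
    "\n",
    "Because $G^{l}$ sums over the positions $k$, it discards spatial layout and keeps only which channels tend to activate together, which is why it captures texture rather than content."
   ]
  },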
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:39:57.303915",
     "start_time": "2017-10-08T16:39:56.727882Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# preprocess the style image (keep only the first three channels)\n",
    "style_arr = preprocess(np.expand_dims(style_image, axis=0)[:,:,:,:3])\n",
    "shp = style_arr.shape\n",
    "\n",
    "print(shp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:41:13.874294",
     "start_time": "2017-10-08T16:41:12.607222Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# get VGG model\n",
    "#model = VGG16(include_top=False, pooling='avg', input_shape=shp[1:]) #input_tensor=input_tensor\n",
    "model = VGG16_Avg(include_top=False, input_shape=shp[1:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:13:13.617189",
     "start_time": "2017-10-08T16:13:12.394119Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:41:14.440327",
     "start_time": "2017-10-08T16:41:13.875294Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "outputs = {l.name: l.output for l in model.layers}\n",
    "layers = [outputs['block{}_conv1'.format(o)] for o in range(1,3)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:41:15.310376",
     "start_time": "2017-10-08T16:41:14.442327Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "layers_model = Model(model.input, layers)\n",
    "targs = [K.variable(o) for o in layers_model.predict(style_arr)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:41:38.248858",
     "start_time": "2017-10-09T08:41:37.651824Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def style_loss(x, targ):\n",
    "    return K.mean(metrics.mse(gram_matrix(x), gram_matrix(targ)))  # K.mean reduces to a scalar\n",
    "    #S = gram_matrix(style)\n",
    "    #C = gram_matrix(combination)\n",
    "    #channels = 3\n",
    "    #size = height * width\n",
    "    #return K.sum(K.square(S - C)) / (4. * (channels ** 2) * (size ** 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:41:16.732458",
     "start_time": "2017-10-08T16:41:15.887409Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "loss = sum(style_loss(l1[0], l2[0]) for l1,l2 in zip(layers, targs))\n",
    "grads = K.gradients(loss, model.input)\n",
    "style_fn = K.function([model.input], [loss]+grads)\n",
    "evaluator = Evaluator(style_fn, shp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:40:35.517100",
     "start_time": "2017-10-08T16:40:34.948068Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "rand_img = lambda shape: np.random.uniform(-2.5, 2.5, shape)/1\n",
    "x = rand_img(shp)\n",
    "#x = scipy.ndimage.filters.gaussian_filter(x, [0,2,2,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:40:36.587162",
     "start_time": "2017-10-08T16:40:35.901122Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "plt.imshow(x[0]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-08T16:43:39.962650",
     "start_time": "2017-10-08T16:41:18.927583Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "iterations=10\n",
    "x = rand_img(shp)\n",
    "x = solve_image(evaluator, iterations, x, shp, dest_dir='recreate_style')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Style Transfer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we finally use both the content and style images to perform the style transfer task."
   ]
  },
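  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reading off the cells in this section (the symbols are shorthand assumed here, not names from the code), the combined objective is approximately:\n",
    "\n",
    "$$\\mathcal{L}(x) = \\sum_{l=1}^{5} w_{l} \\, \\mathcal{L}^{\\,l}_{\\text{style}}(x) + \\tfrac{1}{2} \\, \\mathcal{L}_{\\text{content}}(x)$$\n",
    "\n",
    "where $w = (0.05, 0.2, 0.2, 0.25, 0.3)$ are the `style_wgts` applied to `block1_conv2` through `block5_conv2`, and the content term is computed on `block4_conv2`. Raising the later weights gives more influence to the deeper layers' style statistics."
   ]
  },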
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:49:07.881575",
     "start_time": "2017-10-09T08:49:07.277541Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# preprocess content and style images\n",
    "content_arr = preprocess(np.expand_dims(content_image, axis=0))\n",
    "style_arr = preprocess(np.expand_dims(style_image, axis=0))\n",
    "shp = content_arr.shape\n",
    "\n",
    "print(content_arr.shape)\n",
    "print(style_arr.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:49:09.670678",
     "start_time": "2017-10-09T08:49:08.478609Z"
    },
    "collapsed": false
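   },
   "outputs": [],
   "source": [
    "# get VGG model\n",
    "# note: in Keras, pooling='avg' only applies global average pooling to the final output;\n",
    "# it does not replace the internal max-pooling layers\n",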
    "model = VGG16(include_top=False, input_shape=shp[1:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:49:10.266712",
     "start_time": "2017-10-09T08:49:09.671678Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "outputs = {l.name: l.output for l in model.layers}\n",
    "style_layers = [outputs['block{}_conv2'.format(o)] for o in range(1,6)]\n",
    "content_name = 'block4_conv2'\n",
    "content_layer = outputs[content_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:49:11.332773",
     "start_time": "2017-10-09T08:49:10.268712Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "style_model = Model(model.input, style_layers)\n",
    "style_targs = [K.variable(o) for o in style_model.predict(style_arr)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:49:12.482838",
     "start_time": "2017-10-09T08:49:11.334773Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "content_model = Model(model.input, content_layer)\n",
    "content_targ = K.variable(content_model.predict(content_arr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:49:13.107874",
     "start_time": "2017-10-09T08:49:12.484839Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "style_wgts = [0.05,0.2,0.2,0.25,0.3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:49:14.718966",
     "start_time": "2017-10-09T08:49:13.110874Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "loss = sum(style_loss(l1[0], l2[0])*w\n",
    "           for l1,l2,w in zip(style_layers, style_targs, style_wgts))\n",
    "loss += K.mean(metrics.mse(content_layer, content_targ))/2  # K.mean reduces to a scalar\n",
    "grads = K.gradients(loss, model.input)\n",
    "transfer_fn = K.function([model.input], [loss]+grads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:49:15.302000",
     "start_time": "2017-10-09T08:49:14.720966Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "evaluator = Evaluator(transfer_fn, shp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:49:16.683079",
     "start_time": "2017-10-09T08:49:15.947037Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "iterations=10\n",
    "x = np.random.uniform(-2.5, 2.5, shp)\n",
    "plt.imshow(x[0]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-09T08:51:32.723860",
     "start_time": "2017-10-09T08:49:17.333116Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "x = solve_image(evaluator, iterations, x, shp, dest_dir=os.path.join('results', 'style_transfer'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Different Approach (TOFIX)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See [Keras example](https://github.com/fchollet/keras/blob/master/examples/neural_style_transfer.py)\n",
    "\n",
    "This approach feeds the concatenation of the content, style and combination images directly to a single network, whereas the previous approach builds two different models and combines their losses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
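   },
   "outputs": [],
   "source": [
    "# NOTE: run this cell after `loss` and `style_weight` are defined; `layers` must be a\n",
    "# dict mapping layer names to the outputs of a model built on the concatenated input\n",
    "feature_layers = ['block1_conv2', 'block2_conv2',\n",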
    "                  'block3_conv3', 'block4_conv3',\n",
    "                  'block5_conv3']\n",
    "for layer_name in feature_layers:\n",
    "    layer_features = layers[layer_name]\n",
    "    style_features = layer_features[1, :, :, :]\n",
    "    combination_features = layer_features[2, :, :, :]\n",
    "    sl = style_loss(style_features, combination_features)\n",
    "    loss += (style_weight / len(feature_layers)) * sl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-06-22T12:40:22.139036",
     "start_time": "2017-06-22T12:40:22.091033Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# K is keras.backend (imported at the top); content_arr/style_arr are the preprocessed images\n",
    "content_image = K.variable(content_arr)\n",
    "style_image = K.variable(style_arr)\n",
    "combination_image = K.placeholder((1, height, width, 3))\n",
    "#if K.image_data_format() == 'channels_first':\n",
    "#    combination_image = K.placeholder((1, 3, height, width))\n",
    "#else:\n",
    "#    combination_image = K.placeholder((1, height, width, 3))\n",
    "\n",
    "input_tensor = K.concatenate([content_image,\n",
    "                              style_image,\n",
    "                              combination_image], axis=0)\n",
    "\n",
    "# build VGG16 on the concatenated batch and index its layer outputs by name\n",
    "model = VGG16(input_tensor=input_tensor, include_top=False)\n",
    "layers = {l.name: l.output for l in model.layers}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-06-22T12:40:34.350735",
     "start_time": "2017-06-22T12:40:34.340734Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "content_weight = 0.025\n",
    "style_weight = 5.0\n",
    "total_variation_weight = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-06-22T12:40:35.279788",
     "start_time": "2017-06-22T12:40:35.263787Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "loss = K.variable(0.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-06-22T12:40:35.748815",
     "start_time": "2017-06-22T12:40:35.724813Z"
    },
    "collapsed": false
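   },
   "outputs": [],
   "source": [
    "# content loss as in the referenced Keras example: sum of squared differences\n",
    "def content_loss(base, combination):\n",
    "    return K.sum(K.square(combination - base))\n",
    "\n",
    "layer_features = layers['block2_conv2']\n",
    "content_image_features = layer_features[0, :, :, :]\n",
    "combination_features = layer_features[2, :, :, :]\n",
    "\n",
    "loss += content_weight * content_loss(content_image_features,\n",
    "                                      combination_features)"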
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-06-22T12:40:38.343963",
     "start_time": "2017-06-22T12:40:38.285960Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# penalize differences between neighbouring pixels to keep the result smooth\n",
    "def total_variation_loss(x):\n",
    "    a = K.square(x[:, :height-1, :width-1, :] - x[:, 1:, :width-1, :])\n",
    "    b = K.square(x[:, :height-1, :width-1, :] - x[:, :height-1, 1:, :])\n",
    "    return K.sum(K.pow(a + b, 1.25))\n",
    "\n",
    "loss += total_variation_weight * total_variation_loss(combination_image)"
   ]
  },
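  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In symbols, the `total_variation_loss` above can be read as follows, where $i$ and $j$ index rows and columns and the sum also runs over channels (the notation is assumed here for illustration):\n",
    "\n",
    "$$\\mathcal{L}_{\\text{TV}}(x) = \\sum_{i,j} \\left( \\left( x_{i,j} - x_{i+1,j} \\right)^{2} + \\left( x_{i,j} - x_{i,j+1} \\right)^{2} \\right)^{1.25}$$\n",
    "\n",
    "A smoother image gives a smaller value, so this term acts as a regularizer against high-frequency noise in the generated image."
   ]
  },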
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-06-22T12:40:39.278016",
     "start_time": "2017-06-22T12:40:38.526973Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "grads = K.gradients(loss, combination_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-06-22T12:40:39.324019",
     "start_time": "2017-06-22T12:40:39.279016Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "outputs = [loss]\n",
    "outputs += grads\n",
    "f_outputs = K.function([combination_image], outputs)\n",
    "\n",
    "def eval_loss_and_grads(x):\n",
    "    x = x.reshape((1, height, width, 3))\n",
    "    outs = f_outputs([x])\n",
    "    loss_value = outs[0]\n",
    "    if len(outs[1:]) == 1:\n",
    "        grad_values = outs[1].flatten().astype('float64')\n",
    "    else:\n",
    "        grad_values = np.array(outs[1:]).flatten().astype('float64')\n",
    "    return loss_value, grad_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-06-22T12:40:41.001115",
     "start_time": "2017-06-22T12:40:40.969113Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "x = np.random.uniform(0, 255, (1, height, width, 3)) - 128."
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda env:tensorflow-gpu]",
   "language": "python",
   "name": "conda-env-tensorflow-gpu-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "12px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 4,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
